{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0da10bde",
   "metadata": {},
   "source": [
    "# Running a Simple CNN Inference on CIFAR-10\n",
    "\n",
    "This tutorial demonstrates how to use **TT-NN** to perform inference with a simple Convolutional Neural Network (CNN) on the CIFAR-10 dataset.\n",
    "\n",
    "We will:\n",
    "\n",
    "- Load the CIFAR-10 dataset.\n",
    "- Define a simple CNN using TT-NN operations.\n",
    "- Run inference on sample images.\n",
    "- Observe outputs and accuracy."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "233d6aa4",
   "metadata": {},
   "source": [
    "## Setup and Imports\n",
    "\n",
    "In this script, several libraries are imported to support image classification using a simple CNN on the CIFAR-10 dataset. The OS module checks if pretrained weight files exist on disk. Torch loads model weights, torchvision and its transforms submodule downloads the CIFAR-10 dataset and applies preprocessing, converting images to tensors and normalizing pixel values for example. The TT-NN library is Tenstorrent's neural network API, responsible for interfacing with Tenstorrent hardware. TT-NN performs operations like convolution, pooling, activation, linear layers, data layout management, and type conversions between PyTorch and TT-NN formats. Finally, loguru logs messages and debugging output, providing insights into model operations and predictions throughout the inference process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "519491bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import ttnn\n",
    "from loguru import logger"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87cab60f",
   "metadata": {},
   "source": [
    "## Open the Device\n",
    "\n",
    "Create the device to run the program with custom L1 memory config. The following parameter allocates on-chip L1 memory for sliding-window operations like convolutions, and other kernels that need quick, scratchpad-like memory: `l1_small_size`.  8 kB is enough for simple CNNS, complex models require up to 32 kB or more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cb980bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = ttnn.open_device(device_id=0, l1_small_size=8192)\n",
    "logger.info(\"\\n--- Simple CNN Inference Using TT-NN on CIFAR-10 ---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc0e6b24",
   "metadata": {},
   "source": [
    "## Load the CIFAR-10 Dataset\n",
    "\n",
    "Normalize images and load the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d8d1d2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define input transforms: Convert to tensor and normalize\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "# Load CIFAR-10 test data\n",
    "testset = torchvision.datasets.CIFAR10(root=\"./data\", train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)"
   ]
  },
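  {
   "cell_type": "markdown",
   "id": "9c2e4f18",
   "metadata": {},
   "source": [
    "As a quick sanity check (plain PyTorch, independent of TT-NN), `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` maps each channel from [0, 1] to [-1, 1] by computing `(x - 0.5) / 0.5` per pixel:\n",
    "\n",
    "``` python\n",
    "import torch\n",
    "\n",
    "x = torch.tensor([0.0, 0.5, 1.0])  # pixel values after ToTensor()\n",
    "normalized = (x - 0.5) / 0.5       # what Normalize applies to each channel\n",
    "print(normalized.tolist())         # [-1.0, 0.0, 1.0]\n",
    "```"
   ]
  },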
  {
   "cell_type": "markdown",
   "id": "31ad440d",
   "metadata": {},
   "source": [
    "## Load or Initialize Weights\n",
    "\n",
    "Optimally, pretrained weights are loaded and used for the model, but in case the weights file is not found, default to random values which will likely yield poor results. Run\n",
    "the provided `train_and_export_cnn.py` script to generate weights to a file named `simple_cnn_cifar10_weights.pt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6634a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(\"simple_cnn_cifar10_weights.pt\"):\n",
    "    weights = torch.load(\"simple_cnn_cifar10_weights.pt\")\n",
    "    weights = {\n",
    "        k: ttnn.from_torch(v, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device)\n",
    "        for k, v in weights.items()\n",
    "    }\n",
    "    logger.info(\"Loaded pretrained weights\")\n",
    "else:\n",
    "    logger.warning(\"Weights not found, using random weights\")\n",
    "    torch.manual_seed(0)\n",
    "    weights = {\n",
    "        \"conv1.weight\": ttnn.rand((16, 3, 3, 3), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "        \"conv1.bias\": ttnn.rand((16,), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "        \"conv2.weight\": ttnn.rand((32, 16, 3, 3), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "        \"conv2.bias\": ttnn.rand((32,), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "        \"fc1.weight\": ttnn.rand((128, 2048), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "        \"fc1.bias\": ttnn.rand((128,), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "        \"fc2.weight\": ttnn.rand((10, 128), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "        \"fc2.bias\": ttnn.rand((10,), layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device),\n",
    "    }"
   ]
  },
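  {
   "cell_type": "markdown",
   "id": "6d41a7be",
   "metadata": {},
   "source": [
    "The weights file is expected to be a plain PyTorch state dict keyed by the names used above (`conv1.weight` through `fc2.bias`). As a sketch only (the real `train_and_export_cnn.py` also trains the network; the module definition below is an assumption that merely matches the shapes above), a shape-compatible file could be produced like this:\n",
    "\n",
    "``` python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class SimpleCNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)\n",
    "        self.fc1 = nn.Linear(32 * 8 * 8, 128)  # 2048 features after two 2x2 pools on a 32x32 input\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "model = SimpleCNN()  # untrained here; train first for useful weights\n",
    "torch.save(model.state_dict(), 'simple_cnn_cifar10_weights.pt')\n",
    "```"
   ]
  },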
  {
   "cell_type": "markdown",
   "id": "28c50dbd",
   "metadata": {},
   "source": [
    "## Define Convolution and Pooling Stage\n",
    "\n",
    "The function, `conv_pool_stage`, encapsulates a typical convolutional neural network stage where an input tensor undergoes a 2D convolution followed by an activation and a max pooling operation, all using Tenstorrent's TT-NN API. It accepts an input tensor in NHWC layout, along with metadata like shape, number of output channels, references to specific weight and bias tensors, activation type (e.g., ReLU), and the target hardware device. First, it extracts the appropriate weight and bias tensors from the given dictionary and reshapes the bias to a broadcastable shape. It defines convolution parameters—kernel sizegur, stride, and padding. It sets up a TT-NN specific configuration including the activation function. If enabled, it logs details like tensor shapes and convolution parameters for debugging the first sample. The convolution is then performed using `ttnn.conv2d`, followed by a max pooling operation configured with standard 2×2 kernel and stride values. Again, if logging is enabled, pooling parameters and resulting tensor shapes are recorded. Finally, the resulting TT tensor after max pooling is returned for use in the next stage of the network. This function modularizes a common pattern in CNNs and provides flexibility for different layers and debug logging. \n",
    "\n",
    "For more information on convolution functions see: [ttnn.Conv2dConfig](../../api/ttnn.Conv2dConfig.rst)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd4403c",
   "metadata": {},
   "outputs": [],
   "source": "def conv_pool_stage(\n    input_tensor: ttnn.Tensor,\n    input_NHWC: ttnn.Shape,\n    conv_outchannels: int,\n    weights: dict,\n    weight_str: str,\n    bias_str: str,\n    activation: ttnn.UnaryWithParam,\n    device: ttnn.Device,\n    log_first_sample: bool = False,\n) -> ttnn.Tensor:\n    \"\"\"\n    Perform convolution + activation + max pooling using TT-NN.\n    Args:\n        input_tensor: Input TT tensor in NHWC format.\n        input_NHWC: Tuple representing (Batch, Height, Width, Channels) of the input tensor.\n        conv_outchannels: Number of output channels for the convolution layer.\n        weights: Dictionary containing model weights and biases.\n        weight_str: Key name for convolution weights in the weights dict.\n        bias_str: Key name for convolution biases in the weights dict.\n        activation: Activation function as UnaryWithParam to apply after conv.\n        device: Target TT device to execute the operations on.\n        log_first_sample: Whether to log detailed info (used for debugging first sample).\n    Returns:\n        Output tensor after conv + max pooling (TT format).\n    \"\"\"\n    # Extract weight and bias tensors from weights dictionary\n    W = weights[weight_str]\n    B = weights[bias_str]\n    B = ttnn.reshape(B, (1, 1, 1, -1))  # Ensure bias is in correct shape for TT-NN\n\n    # Define convolution parameters\n    conv_kernel_size = (3, 3)\n    conv_stride = (1, 1)\n    conv_padding = (1, 1)\n\n    # Set up TT-NN convolution configuration including activation function\n    conv_config = ttnn.Conv2dConfig(weights_dtype=ttnn.bfloat16, activation=activation)\n\n    # Optional detailed logging for the first sample (shape, config, etc.)\n    if log_first_sample:\n        logger.info(\"=====================================================================\")\n        logger.info(\"Input parameters to conv2d:\")\n        logger.info(f\"  input_tensor shape: {input_tensor.shape}\")\n        
logger.info(f\"  weight_tensor shape: {W.shape}\")\n        logger.info(f\"  bias_tensor shape: {B.shape}\")\n        logger.info(f\"  in_channels: {input_NHWC[3]}\")\n        logger.info(f\"  out_channels: {conv_outchannels}\")\n        logger.info(f\"  device: {device}\")\n        logger.info(f\"  kernel_size: {conv_kernel_size}\")\n        logger.info(f\"  stride: {conv_stride}\")\n        logger.info(f\"  padding: {conv_padding}\")\n        logger.info(f\"  batch_size: {input_NHWC[0]}\")\n        logger.info(f\"  input_height: {input_NHWC[1]}\")\n        logger.info(f\"  input_width: {input_NHWC[2]}\")\n        logger.info(f\"  conv_config: {conv_config}\")\n        logger.info(f\"  groups: {0}\")\n\n    # Perform convolution\n    conv1_out = ttnn.conv2d(\n        input_tensor=input_tensor,\n        weight_tensor=W,\n        bias_tensor=B,\n        in_channels=input_NHWC[3],\n        out_channels=conv_outchannels,\n        device=device,\n        kernel_size=conv_kernel_size,\n        stride=conv_stride,\n        padding=conv_padding,\n        batch_size=input_NHWC[0],\n        input_height=input_NHWC[1],\n        input_width=input_NHWC[2],\n        conv_config=conv_config,\n        groups=0,\n    )\n\n    # Define max pooling parameters\n    max_pool2d_kernel_size = [2, 2]\n    max_pool2d_stride = [2, 2]\n    max_pool2d_padding = [0, 0]\n    max_pool2d_dilation = [1, 1]\n\n    # Optional logging for max pooling input and parameters\n    if log_first_sample:\n        logger.info(\"Input parameters to max_pool2d:\")\n        logger.info(f\"  input shape: {conv1_out.shape}\")\n        logger.info(f\"  batch_size: {input_NHWC[0]}\")\n        logger.info(f\"  input_h: {input_NHWC[1]}\")\n        logger.info(f\"  input_w: {input_NHWC[2]}\")\n        logger.info(f\"  channels: {conv_outchannels}\")\n        logger.info(f\"  kernel_size: {max_pool2d_kernel_size}\")\n        logger.info(f\"  stride: {max_pool2d_stride}\")\n        logger.info(f\"  padding: 
{max_pool2d_padding}\")\n        logger.info(f\"  dilation: {max_pool2d_dilation}\")\n        logger.info(f\"  ceil_mode: {False}\")\n\n    # Perform max pooling\n    max_pool2d_out = ttnn.max_pool2d(\n        conv1_out,\n        batch_size=input_NHWC[0],\n        input_h=input_NHWC[1],\n        input_w=input_NHWC[2],\n        channels=conv_outchannels,\n        kernel_size=max_pool2d_kernel_size,\n        stride=max_pool2d_stride,\n        padding=max_pool2d_padding,\n        dilation=max_pool2d_dilation,\n        ceil_mode=False,\n    )\n\n    # Log output shape after pooling\n    if log_first_sample:\n        logger.info(f\"max_pool2d output shape: {max_pool2d_out.shape}\")\n        logger.info(\"=====================================================================\")\n\n    return max_pool2d_out"
  },
  {
   "cell_type": "markdown",
   "id": "b9604d5d",
   "metadata": {},
   "source": [
    "## Run Inference on Test Samples\n",
    "\n",
    "This code sample performs inference on the first five test samples from the CIFAR-10 dataset using a simple convolutional neural network (SimpleCNN) running on Tenstorrent hardware via the TT-NN API. It initializes counters tracking correct predictions and total samples processed. For each sample, it converts the input image from a PyTorch tensor to a TT-NN tensor, rearranging its layout from NCHW to NHWC format. The image is then passed through two convolution and pooling stages using the `conv_pool_stage` function. The output is flattened and passed through two fully connected layers (FC1 and FC2), with ReLU applied after FC1. The weights and biases for these layers are converted to the appropriate TT-NN format with tiling and transposing as needed. After obtaining the final logits from FC2, the output is converted back to a PyTorch tensor, and the predicted label is determined by taking the index of the highest logit. The prediction is compared to the true label to update the accuracy counters, and the result for each sample is logged. Finally, the overall inference accuracy is printed after processing the five samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ef5ea5f",
   "metadata": {},
   "outputs": [],
   "source": "correct = 0\ntotal = 0\n\n# Run inference on a few test samples\nfor i, (image, label) in enumerate(testloader):\n    if i >= 5:\n        break\n\n    # Convert image to TT tensor\n    ttnn_image = ttnn.from_torch(image, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, device=device)\n    ttnn_image_permuated = ttnn.permute(ttnn_image, (0, 2, 3, 1))  # NCHW -> NHWC\n\n    # Only log details for first sample\n    log_this = i == 0\n\n    # Apply first conv + pool stage\n    conv1_pool = conv_pool_stage(\n        ttnn_image_permuated,\n        ttnn_image_permuated.shape,\n        16,\n        weights,\n        \"conv1.weight\",\n        \"conv1.bias\",\n        ttnn.UnaryWithParam(ttnn.UnaryOpType.RELU),\n        device,\n        log_first_sample=log_this,\n    )\n\n    # Apply second conv + pool stage\n    conv2_pool = conv_pool_stage(\n        conv1_pool,\n        (1, 16, 16, 16),\n        32,\n        weights,\n        \"conv2.weight\",\n        \"conv2.bias\",\n        ttnn.UnaryWithParam(ttnn.UnaryOpType.RELU),\n        device,\n        log_first_sample=log_this,\n    )\n\n    # Flatten for FC layers\n    B, H, W, C = conv2_pool.shape\n    out_flat = ttnn.to_torch(conv2_pool)  # Convert back to torch\n    out_flat = out_flat.permute(0, 3, 1, 2).contiguous().view(B, -1)  # NHWC -> NCHW -> Flatten\n\n    # Prepare fully connected layers\n    W3 = weights[\"fc1.weight\"]\n    B3 = weights[\"fc1.bias\"].reshape((1, -1))  # Reshape bias for broadcast compatibility\n    W4 = weights[\"fc2.weight\"]\n    B4 = weights[\"fc2.bias\"]\n\n    # Convert to TT format for FC1\n    W3_tt = ttnn.to_layout(ttnn.transpose(W3, 0, 1), ttnn.TILE_LAYOUT)\n    B3_tt = ttnn.to_layout(B3.reshape((1, -1)), ttnn.TILE_LAYOUT)\n\n    # Convert input to TT format\n    x_tt = ttnn.from_torch(out_flat, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n\n    # Apply FC1 + ReLU\n    out = ttnn.linear(x_tt, W3_tt, bias=B3_tt)\n    out = ttnn.relu(out)\n\n    # Convert 
to TT format for FC2\n    W4_tt = ttnn.to_layout(ttnn.transpose(W4, 0, 1), ttnn.TILE_LAYOUT)\n    B4_tt = ttnn.to_layout(B4.reshape((1, -1)), ttnn.TILE_LAYOUT)\n\n    # Apply FC2 (output logits)\n    out = ttnn.linear(out, W4_tt, bias=B4_tt)\n\n    # Convert prediction back to torch\n    prediction = ttnn.to_torch(out)\n    predicted_label = torch.argmax(prediction, dim=1).item()\n    correct += predicted_label == label.item()\n    total += 1\n\n    logger.info(f\"Sample {i+1}: Predicted={predicted_label}, Actual={label.item()}\")\n\nlogger.info(f\"\\nTT-NN SimpleCNN Inference Accuracy: {correct}/{total} = {100.0 * correct / total:.2f}%\")"
  },
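  {
   "cell_type": "markdown",
   "id": "2f8c3a90",
   "metadata": {},
   "source": [
    "The flatten step above relies on simple shape arithmetic: two padding-preserving 3\u00d73 convolutions and two 2\u00d72 pools turn the 32\u00d732 input into an 8\u00d78 map with 32 channels, i.e. 32 \u00d7 8 \u00d7 8 = 2048 features for FC1. A plain-PyTorch sketch of the NHWC \u2192 NCHW \u2192 flatten reshape:\n",
    "\n",
    "``` python\n",
    "import torch\n",
    "\n",
    "x = torch.zeros(1, 8, 8, 32)  # NHWC output of the second conv + pool stage\n",
    "flat = x.permute(0, 3, 1, 2).contiguous().view(1, -1)  # NHWC -> NCHW -> flatten\n",
    "print(tuple(flat.shape))      # (1, 2048)\n",
    "```"
   ]
  },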
  {
   "cell_type": "markdown",
   "id": "728d82f7",
   "metadata": {},
   "source": [
    "## Close the Device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e8d4e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "ttnn.close_device(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f125bcbf",
   "metadata": {},
   "source": [
    "We have built and run a simple CNN using Tenstorrent's TT-NN library on the CIFAR-10 dataset, observed predictions, and computed accuracy on a few samples.\n",
    "\n",
    "For full-scale inference or training, pre-trained weights should be used, and additional optimization strategies may be applied."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f134401",
   "metadata": {},
   "source": [
    "## Full Example and Output\n",
    "\n",
    "Lets put everything together in a complete example that can be run\n",
    "directly.\n",
    "\n",
    "[ttnn_simplecnn_inference.py](https://github.com/tenstorrent/tt-metal/tree/main/ttnn/tutorials/basic_python/ttnn_simplecnn_inference.py)\n",
    "\n",
    "Running this script will generate the following output:\n",
    "\n",
    "``` console\n",
    "$ python3 $TT_METAL_HOME/ttnn/tutorials/basic_python/ttnn_simplecnn_inference.py\n",
    "2025-07-07 13:10:17.041 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:10:17.043 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:10:17.050 | info     |          Device | Opening user mode device driver (tt_cluster.cpp:190)\n",
    "2025-07-07 13:10:17.050 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:10:17.051 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:10:17.057 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:10:17.058 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:10:17.064 | info     |   SiliconDriver | Harvesting mask for chip 0 is 0x100 (NOC0: 0x100, simulated harvesting mask: 0x0). (cluster.cpp:282)\n",
    "2025-07-07 13:10:17.161 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:10:17.224 | info     |   SiliconDriver | Opening local chip ids/pci ids: {0}/[7] and remote chip ids {} (cluster.cpp:147)\n",
    "2025-07-07 13:10:17.235 | info     |   SiliconDriver | Software version 6.0.0, Ethernet FW version 6.14.0 (Device 0) (cluster.cpp:1039)\n",
    "2025-07-07 13:10:17.321 | info     |           Metal | AI CLK for device 0 is:   1000 MHz (metal_context.cpp:128)\n",
    "2025-07-07 13:10:17.889 | info     |           Metal | Initializing device 0. Program cache is enabled (device.cpp:428)\n",
    "2025-07-07 13:10:17.891 | warning  |           Metal | Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: 22 (hardware_command_queue.cpp:74)\n",
    "2025-07-07 13:10:19.734 | INFO     | __main__:main:15 - \n",
    "--- Simple CNN Inference Using TT-NN on CIFAR-10 ---\n",
    "Files already downloaded and verified\n",
    "2025-07-07 13:10:20.471 | INFO     | __main__:main:30 - Loaded pretrained weights\n",
    "2025-07-07 13:10:21.075 | INFO     | __main__:conv_pool_stage:86 - =====================================================================\n",
    "2025-07-07 13:10:21.075 | INFO     | __main__:conv_pool_stage:87 - Input parameters to conv2d:\n",
    "2025-07-07 13:10:21.075 | INFO     | __main__:conv_pool_stage:88 -   input_tensor shape: Shape([1, 32, 32, 3])\n",
    "2025-07-07 13:10:21.075 | INFO     | __main__:conv_pool_stage:89 -   weight_tensor shape: Shape([16, 3, 3, 3])\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:90 -   bias_tensor shape: Shape([1, 1, 1, 16])\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:91 -   in_channels: 3\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:92 -   out_channels: 16\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:93 -   device: MeshDevice(1x1 grid, 1 devices)\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:94 -   kernel_size: (3, 3)\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:95 -   stride: (1, 1)\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:96 -   padding: (1, 1)\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:97 -   batch_size: 1\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:98 -   input_height: 32\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:99 -   input_width: 32\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:100 -   conv_config: Conv2dConfig(weights_dtype=DataType::BFLOAT16,activation=relu,deallocate_activation=0,reallocate_halo_output=1,act_block_h_override=0,act_block_w_div=1,reshard_if_not_optimal=0,override_sharding_config=0,shard_layout=std::nullopt,core_grid=std::nullopt,transpose_shards=0,output_layout=Layout::TILE,enable_act_double_buffer=0,enable_weights_double_buffer=0,enable_split_reader=0,enable_subblock_padding=0,in_place=0,enable_kernel_stride_folding=0)\n",
    "2025-07-07 13:10:21.076 | INFO     | __main__:conv_pool_stage:101 -   groups: 0\n",
    "2025-07-07 13:10:22.960 | INFO     | __main__:conv_pool_stage:129 - Input parameters to max_pool2d:\n",
    "2025-07-07 13:10:22.960 | INFO     | __main__:conv_pool_stage:130 -   input shape: Shape([1, 1, 1024, 16])\n",
    "2025-07-07 13:10:22.960 | INFO     | __main__:conv_pool_stage:131 -   batch_size: 1\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:132 -   input_h: 32\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:133 -   input_w: 32\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:134 -   channels: 16\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:135 -   kernel_size: [2, 2]\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:136 -   stride: [2, 2]\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:137 -   padding: [0, 0]\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:138 -   dilation: [1, 1]\n",
    "2025-07-07 13:10:22.961 | INFO     | __main__:conv_pool_stage:139 -   ceil_mode: False\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:157 - max_pool2d output shape: Shape([1, 1, 256, 32])\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:158 - =====================================================================\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:86 - =====================================================================\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:87 - Input parameters to conv2d:\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:88 -   input_tensor shape: Shape([1, 1, 256, 32])\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:89 -   weight_tensor shape: Shape([32, 16, 3, 3])\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:90 -   bias_tensor shape: Shape([1, 1, 1, 32])\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:91 -   in_channels: 16\n",
    "2025-07-07 13:10:24.026 | INFO     | __main__:conv_pool_stage:92 -   out_channels: 32\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:93 -   device: MeshDevice(1x1 grid, 1 devices)\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:94 -   kernel_size: (3, 3)\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:95 -   stride: (1, 1)\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:96 -   padding: (1, 1)\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:97 -   batch_size: 1\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:98 -   input_height: 16\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:99 -   input_width: 16\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:100 -   conv_config: Conv2dConfig(weights_dtype=DataType::BFLOAT16,activation=relu,deallocate_activation=0,reallocate_halo_output=1,act_block_h_override=0,act_block_w_div=1,reshard_if_not_optimal=0,override_sharding_config=0,shard_layout=std::nullopt,core_grid=std::nullopt,transpose_shards=0,output_layout=Layout::TILE,enable_act_double_buffer=0,enable_weights_double_buffer=0,enable_split_reader=0,enable_subblock_padding=0,in_place=0,enable_kernel_stride_folding=0)\n",
    "2025-07-07 13:10:24.027 | INFO     | __main__:conv_pool_stage:101 -   groups: 0\n",
    "2025-07-07 13:10:25.120 | INFO     | __main__:conv_pool_stage:129 - Input parameters to max_pool2d:\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:130 -   input shape: Shape([1, 1, 256, 32])\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:131 -   batch_size: 1\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:132 -   input_h: 16\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:133 -   input_w: 16\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:134 -   channels: 32\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:135 -   kernel_size: [2, 2]\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:136 -   stride: [2, 2]\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:137 -   padding: [0, 0]\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:138 -   dilation: [1, 1]\n",
    "2025-07-07 13:10:25.121 | INFO     | __main__:conv_pool_stage:139 -   ceil_mode: False\n",
    "2025-07-07 13:10:25.669 | INFO     | __main__:conv_pool_stage:157 - max_pool2d output shape: Shape([1, 1, 64, 32])\n",
    "2025-07-07 13:10:25.669 | INFO     | __main__:conv_pool_stage:158 - =====================================================================\n",
    "2025-07-07 13:10:30.120 | INFO     | __main__:main:238 - Sample 1: Predicted=8, Actual=3\n",
    "2025-07-07 13:10:30.136 | INFO     | __main__:main:238 - Sample 2: Predicted=8, Actual=8\n",
    "2025-07-07 13:10:30.151 | INFO     | __main__:main:238 - Sample 3: Predicted=8, Actual=8\n",
    "2025-07-07 13:10:30.166 | INFO     | __main__:main:238 - Sample 4: Predicted=0, Actual=0\n",
    "2025-07-07 13:10:30.181 | INFO     | __main__:main:238 - Sample 5: Predicted=6, Actual=6\n",
    "2025-07-07 13:10:30.181 | INFO     | __main__:main:240 - \n",
    "TT-NN SimpleCNN Inference Accuracy: 4/5 = 80.00%\n",
    "2025-07-07 13:10:30.181 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)\n",
    "2025-07-07 13:10:30.182 | info     |           Metal | Closing mesh device 0 (mesh_device.cpp:488)\n",
    "2025-07-07 13:10:30.182 | info     |           Metal | Closing device 0 (device.cpp:468)\n",
    "2025-07-07 13:10:30.182 | info     |           Metal | Disabling and clearing program cache on device 0 (device.cpp:783)\n",
    "2025-07-07 13:10:30.183 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
